import torch
import numpy as np

from dada.optimizer import WDA, UGM, DADA, DoG, Prodigy

from dada.model import ModelRunner
from dada.model.worst_instance import WorstCaseModel
from dada.utils import Param, run_different_opts, plot_optimizers_result


class WorstCaseRunner(ModelRunner):
    """Benchmark several optimizers on ``WorstCaseModel`` for each ``q`` in ``q_list``.

    For every ``q`` a fresh set of models/optimizers (WDA, UGM, DoG, Prodigy,
    DADA) is built from the same starting point, run via
    ``run_different_opts``, and all results are plotted together with
    ``plot_optimizers_result``.
    """

    def __init__(self, params: dict):
        """Store ``q_list`` and delegate the remaining setup to ``ModelRunner``.

        Args:
            params: configuration dict; must contain a non-None ``'q_list'``.
                Remaining keys are consumed by ``ModelRunner.__init__``
                (presumably including what sets ``self.vector_size`` —
                TODO confirm in the base class).

        Raises:
            KeyError: if ``'q_list'`` is missing from ``params``.
            ValueError: if ``params['q_list']`` is None.
        """
        self.q_list = params['q_list']
        if self.q_list is None:
            raise ValueError('q_list is None')

        super().__init__(params)

    @staticmethod
    def _build_optimizers(d, p, init, d0):
        """Return ``[optimizer, model]`` pairs for every compared method, all starting at ``init``."""

        def fresh_model():
            # Each model gets its own leaf-tensor copy of the start point so an
            # optimizer's in-place parameter updates cannot leak into the other
            # methods (needed if WorstCaseModel stores init_point without copying).
            start = init.detach().clone().requires_grad_(True)
            return WorstCaseModel(d, p, init_point=start)

        # Dual Averaging Method
        da_model = fresh_model()
        da_optimizer = WDA(da_model.params(), d0=d0)

        # GD With Line Search Method
        gd_line_search_model = fresh_model()
        gd_line_search_optimizer = UGM(gd_line_search_model.params())

        # DoG Method
        dog_model = fresh_model()
        dog_optimizer = DoG(dog_model.params())

        # Prodigy Method
        prodigy_model = fresh_model()
        prodigy_optimizer = Prodigy(prodigy_model.params())

        # DADA Method
        dada_model = fresh_model()
        dada_optimizer = DADA(dada_model.params())

        return [
            [da_optimizer, da_model],
            [gd_line_search_optimizer, gd_line_search_model],
            [dog_optimizer, dog_model],
            [prodigy_optimizer, prodigy_model],
            [dada_optimizer, dada_model],
        ]

    def run(self, iterations, model_name, save_plot, plots_directory):
        """Run every optimizer for each ``q`` in ``q_list`` and plot the results.

        Args:
            iterations: number of iterations passed to ``run_different_opts``.
            model_name: label forwarded to ``plot_optimizers_result``.
            save_plot: forwarded as ``save`` to ``plot_optimizers_result``.
            plots_directory: forwarded to ``plot_optimizers_result``.
        """
        params = [
            Param(names=["p", "d"], values=[q, self.vector_size])
            for q in self.q_list
        ]
        value_distances_per_param = {}
        d_estimation_error_per_param = {}

        # Distances are measured against the origin.
        optimal_point = np.zeros((self.vector_size,))

        # Only the list from the last parameter setting is handed to the plotter.
        optimizers = []

        for param in params:
            print(param)
            p = param.get_param("p")
            d = param.get_param("d")
            init = torch.ones(d, requires_grad=True, dtype=torch.double)
            # Initial distance to the optimum; WDA takes it as a known input.
            d0 = np.linalg.norm(optimal_point - init.clone().detach().numpy())

            optimizers = self._build_optimizers(d, p, init, d0)

            d_estimation_error, value_distances = run_different_opts(optimizers, iterations, optimal_point, log_per=100)
            value_distances_per_param[param] = value_distances
            d_estimation_error_per_param[param] = d_estimation_error

        # Clamp to >= 1: iterations < 10 would give a marker step of 0, which is
        # meaningless (matplotlib's ``markevery`` rejects a zero step).
        plot_optimizers_result(optimizers, params, value_distances_per_param, d_estimation_error_per_param,
                               model_name=model_name, save=save_plot, plots_directory=plots_directory,
                               mark_every=max(1, iterations // 10))
